import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates

# ------------------------
# 1. Load the data
# ------------------------
df = pd.read_excel("Both Data.xlsx", engine="openpyxl")

# Work with MonthYear as text so it can be parsed with an explicit format.
# NOTE(review): if Excel delivers this column as float, the text would look
# like "202001.0" and be coerced to NaT below -- confirm it is integer/text.
df['MonthYear'] = df['MonthYear'].astype(str)

# Parse "YYYYMM" into a monthly pandas Period (unparseable values become NaT)
# so months sort and plot chronologically.
parsed_month_start = pd.to_datetime(df['MonthYear'], format='%Y%m', errors='coerce')
df['Month'] = parsed_month_start.dt.to_period('M')

# Flag the direction of each event:
#   CT = actor 1 is CHN and actor 2 is TWN; TC = the reverse.
is_ct = df['Actor1CountryCode'].eq('CHN') & df['Actor2CountryCode'].eq('TWN')
is_tc = df['Actor1CountryCode'].eq('TWN') & df['Actor2CountryCode'].eq('CHN')

# ------------------------
# 2. 定义计算指标的函数（基于事件样本）
# ------------------------
def _weighted_index(values, weights):
    """Return the weighted mean of *values*, rescaled by ``(x + 10) / 20``.

    Rows where either the value or the weight is missing are left out of
    BOTH the numerator and the denominator, so a missing value cannot drag
    the average towards zero. Returns NaN when the usable weights do not
    sum to a positive number.
    """
    usable = values.notna() & weights.notna()
    total_weight = weights[usable].sum()
    if total_weight > 0:
        weighted_mean = (values[usable] * weights[usable]).sum() / total_weight
        # Maps a score in [-10, 10] onto [0, 1].
        return (weighted_mean + 10) / 20
    return np.nan


def calc_metrics(sample):
    """Compute the three monthly indices for one set of events.

    Parameters
    ----------
    sample : DataFrame
        Events of a single month (typically a bootstrap resample). Must
        contain the columns ``NumMentions``, ``GoldsteinScale``, ``AvgTone``,
        ``Actor1CountryCode`` and ``Actor2CountryCode``.

    Returns
    -------
    tuple
        ``(impact_index, tone_index, balance_index)``:

        * impact_index  -- NumMentions-weighted mean of GoldsteinScale,
          rescaled with ``(x + 10) / 20``.
        * tone_index    -- NumMentions-weighted mean of AvgTone, same rescale.
          NOTE(review): the rescale only lands in [0, 1] if AvgTone stays
          within [-10, 10] -- confirm against the data.
        * balance_index -- share of events with actor 1 = CHN and
          actor 2 = TWN among all events in *sample*.

        Each entry is NaN when its denominator is zero (e.g. empty sample).
    """
    mentions = sample['NumMentions']

    impact_index = _weighted_index(sample['GoldsteinScale'], mentions)
    tone_index = _weighted_index(sample['AvgTone'], mentions)

    # Derive the CHN->TWN flag from the sample itself instead of looking it
    # up in a module-level mask; this keeps the function self-contained and
    # independent of how the sample's index relates to the original frame.
    total_events = len(sample)
    if total_events > 0:
        ct_mask = (sample['Actor1CountryCode'] == 'CHN') & (sample['Actor2CountryCode'] == 'TWN')
        balance_index = ct_mask.sum() / total_events
    else:
        balance_index = np.nan

    return impact_index, tone_index, balance_index


# ------------------------
# 3. Event-level bootstrap
# ------------------------
n_bootstrap = 2000

# Fixed seed so the saved table and the figures are reproducible between
# runs; set to None to draw a fresh random stream on every run.
RANDOM_SEED = 42
rng = np.random.RandomState(RANDOM_SEED)


def ci(vals):
    """Summarise bootstrap replicates as ``(mean, lower, upper)``.

    NaN replicates are discarded first. ``lower`` / ``upper`` are the 2.5th
    and 97.5th percentiles (a 95% percentile interval). If no finite
    replicate remains, all three values are NaN.
    """
    finite = np.asarray(vals, dtype=float)
    finite = finite[~np.isnan(finite)]
    if finite.size == 0:
        return np.nan, np.nan, np.nan
    lower, upper = np.percentile(finite, [2.5, 97.5])
    return finite.mean(), lower, upper


results = []

# groupby partitions the events once, yields the months in sorted order and
# skips rows whose month is missing (NaT), so no empty month is resampled.
for month, month_data in df.groupby('Month'):
    n_events = len(month_data)

    # Each replicate resamples this month's events with replacement
    # (same size as the month) and recomputes all three indices.
    replicates = [
        calc_metrics(month_data.sample(n=n_events, replace=True, random_state=rng))
        for _ in range(n_bootstrap)
    ]
    impact_boot, tone_boot, balance_boot = zip(*replicates)

    # One output row per month: mean and 95% interval for every index.
    row = {'Month': month.to_timestamp()}
    for index_name, boot_values in (
        ('ImpactIndex', impact_boot),
        ('ToneIndex', tone_boot),
        ('BalanceIndex', balance_boot),
    ):
        mean, lower, upper = ci(boot_values)
        row[f'{index_name}_Mean'] = mean
        row[f'{index_name}_Lower'] = lower
        row[f'{index_name}_Upper'] = upper
    results.append(row)

# Collect the monthly rows into a DataFrame
out_df = pd.DataFrame(results)

# ------------------------
# 4. Save the results to Excel
# ------------------------
result_path = "Bootstrap_Monthly_Indices_Full.xlsx"
out_df.to_excel(result_path, index=False)  # index=False: no row-number column
print(f"已保存结果文件：{result_path}")

# ------------------------
# 5. 画图函数（带置信区间）
# ------------------------
def plot_with_ci(df, mean_col, lower_col, upper_col, title, ylabel, fname, color):
    """Plot a monthly series with its confidence band and save it to *fname*.

    Parameters
    ----------
    df : DataFrame
        Must contain a ``Month`` column (x axis) plus the three columns
        named by *mean_col*, *lower_col* and *upper_col*.
    mean_col, lower_col, upper_col : str
        Columns holding the central line and the band's lower/upper edges.
    title : str
        Figure title; also used as the legend label of the line.
    ylabel : str
        Label of the y axis.
    fname : str
        Path of the image file to write.
    color : str
        Matplotlib colour used for both the line and the shaded band.
    """
    fig, ax = plt.subplots(figsize=(12, 5))
    try:
        ax.plot(df['Month'], df[mean_col], '-o', color=color, label=title)
        ax.fill_between(df['Month'], df[lower_col], df[upper_col],
                        color=color, alpha=0.25, label='95% CI')

        ax.set_title(title)
        ax.set_xlabel('Month')
        ax.set_ylabel(ylabel)
        # One major tick every six months, labelled as YYYY-MM.
        ax.xaxis.set_major_locator(mdates.MonthLocator(interval=6))
        ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m'))
        ax.tick_params(axis='x', labelrotation=45)
        ax.legend()

        # Operate on this figure explicitly rather than on pyplot's
        # "current figure", so other open figures cannot be affected.
        fig.tight_layout()
        fig.savefig(fname)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f"已生成图像：{fname}")

# ------------------------
# 6. Draw the three figures
# ------------------------
# Each entry: (mean column, lower column, upper column,
#              title, y-axis label, output file, colour)
figure_specs = (
    ('ImpactIndex_Mean', 'ImpactIndex_Lower', 'ImpactIndex_Upper',
     'Normalized Impact Index', 'ImpactIndex [0-1]', 'ImpactIndex_with_CI.png', 'green'),
    ('ToneIndex_Mean', 'ToneIndex_Lower', 'ToneIndex_Upper',
     'Normalized Tone Index', 'ToneIndex [0-1]', 'ToneIndex_with_CI.png', 'red'),
    ('BalanceIndex_Mean', 'BalanceIndex_Lower', 'BalanceIndex_Upper',
     'Coverage Balance (CT_Events / TotalEvents)', 'BalanceIndex [0-1]', 'BalanceIndex_with_CI.png', 'orange'),
)

for figure_spec in figure_specs:
    plot_with_ci(out_df, *figure_spec)

print("绘图完成。")
